Conversation
yeandy left a comment
I'm wondering how many of these files we need?
- primus/backends/maxtext/input_pipeline/_hf_data_processing.py
- primus/backends/maxtext/input_pipeline/custom_packed_batch.py (I see this is deleted)
- primus/backends/maxtext/layers/attention_op.py
- primus/backends/maxtext/layers/attentions.py (I see this is deleted)
- primus/backends/maxtext/metric_logger.py
- primus/backends/maxtext/train.py
- primus/backends/maxtext/train_utils.py
I think they were added in the past for the purposes of patching. @amd-fuyuajin do you know if these are getting patched into the MaxText codebase when you run the training? Even if it is, it might be the same code as what is found in rocm/jax-training:maxtext-v26.1 actually. @llying-001 might know best.
I updated these files in the Primus repo to stay aligned with the yeandy/update-patches-scaling-patch-v2-checkpoint-restore branch in ROCm/maxtext.
- Add timestamp to log filenames to prevent overwriting across runs
- Move tee logging outside the inline script to capture consolidated multi-node output in a single log file
- Make --nodelist conditional via the NODE_LIST env variable
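A minimal sketch of the logging changes above (the log directory, filename prefix, and echoed message are illustrative, not taken from the actual script):

```shell
# Timestamped log name so repeated runs don't overwrite each other;
# tee is applied outside the inline script so multi-node output is
# consolidated into a single file. LOG_DIR is an illustrative default.
LOG_DIR=${LOG_DIR:-./logs}
mkdir -p "$LOG_DIR"
LOG_FILE="$LOG_DIR/pretrain_$(date +%Y%m%d_%H%M%S).log"
{
  echo "training output from all nodes"
} 2>&1 | tee -a "$LOG_FILE"
```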
- Set TF_CPP_MIN_LOG_LEVEL=2. Without this setting, an error occurs at the end when all training steps complete.
- XLA_FLAGS is case sensitive. Corrected a few values.
- Detect the backend framework in `primus-cli-direct.sh` and install JAX dependencies
- If using AINIC (setting USING_AINIC=1), `03_enable_ainic.sh` will run. The `LD_LIBRARY_PATH` is modified to make sure libraries are correctly loaded for JAX/MaxText.
- Set XLA_PYTHON_CLIENT_MEM_FRACTION=.93 to avoid the HSA_STATUS_ERROR_OUT_OF_RESOURCES error during multi-node training
- Corrected some XLA_FLAGS. The flags are case sensitive; values `true` and `false` do not need to be capitalized.
- Set TF_CPP_MIN_LOG_LEVEL=2 to suppress the error messages at the end of JAX/MaxText training

Here is an example to launch JAX/MaxText training on two nodes:

`./primus-cli --config runner/maxtext-test.yaml slurm srun -N 2 -- train pretrain --config examples/maxtext/configs/MI355X/llama2_7B-pretrain.yaml`
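The environment settings described above can be sketched as exports; the specific XLA flag shown is only an illustrative example of the lowercase true/false convention, not necessarily one of the corrected values:

```shell
# Values taken from the description above; flags may differ per cluster.
export XLA_PYTHON_CLIENT_MEM_FRACTION=.93   # avoid HSA_STATUS_ERROR_OUT_OF_RESOURCES
export TF_CPP_MIN_LOG_LEVEL=2               # suppress error messages at end of training
# XLA_FLAGS values are case sensitive: use lowercase true/false
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
echo "mem fraction: $XLA_PYTHON_CLIENT_MEM_FRACTION, TF log level: $TF_CPP_MIN_LOG_LEVEL"
```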
Problem: when running apt install linux-headers-"$(uname -r)", the package name resolved to the wrong version number on some nodes and caused a "package not found" error. Solution: remove it from the package install list. It does not affect performance.
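The failure mode is that the headers package name is derived from the running kernel, which the package index may not know about. The PR simply removes the package; the following is only a hedged sketch of how such a mismatch could be detected (using standard apt tooling):

```shell
# The headers package name is derived from the running kernel; on some
# nodes this resolves to a version apt does not know about.
pkg="linux-headers-$(uname -r)"
if command -v apt-cache >/dev/null 2>&1 && apt-cache show "$pkg" >/dev/null 2>&1; then
  echo "available: $pkg"
else
  echo "not in package index, skipping: $pkg"
fi
```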
1. Added examples for using AINIC in training
2. Added more examples for running preflight
3. Updated the argument format for the benchmark gemm command. The script was changed, but the document was not updated.
2e31891 to 095b267
Pull request overview
This PR adds comprehensive support for JAX/MaxText backend testing and multi-node training capabilities, including AINIC network integration, improved checkpointing, and various model architecture enhancements.
Changes:
- Updated MaxText submodule to a newer commit
- Added AINIC configuration support with proper environment variable setup and library path ordering
- Enhanced MaxText backend with improved checkpointing, attention mechanisms, and decoder layer implementations
- Refactored dependency installation to detect framework type and install appropriate requirements
Reviewed changes
Copilot reviewed 34 out of 35 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| third_party/maxtext | Updated MaxText submodule reference to newer commit |
| runner/use_ainic.yaml | New configuration file for AINIC network setup with container options |
| runner/primus-cli-direct.sh | Added framework detection logic to install correct dependencies (JAX vs PyTorch) |
| runner/helpers/hooks/train/pretrain/maxtext/prepare.py | Removed problematic linux-headers package, adjusted memory limits and XLA flags |
| runner/helpers/hooks/03_enable_ainic.sh | Fixed LD_LIBRARY_PATH ordering to append instead of prepend paths |
| runner/.primus.yaml | Uncommented InfiniBand device for AINIC support |
| requirements-jax.txt | Simplified to core dependencies only |
| primus/pretrain.py | Enhanced MaxText path detection to support src subdirectory |
| primus/modules/trainer/maxtext/pre_trainer.py | Extended patching to include initialization, checkpointing, config types, and decoder layers |
| primus/configs/modules/maxtext/trainer_base.yaml | Updated configuration with new parameters and removed deprecated options |
| primus/configs/models/maxtext/llama3.1_405B.yaml | New model configuration for Llama 3.1 405B |
| primus/backends/maxtext/train_utils.py | Refactored emergency checkpoint logic and updated to use max_num_checkpoints_to_keep |
| primus/backends/maxtext/train.py | Major refactor with barrier synchronization, improved error handling, and new training features |
| primus/backends/maxtext/metric_logger.py | Updated to use MetadataKey enum constants |
| primus/backends/maxtext/max_utils.py | Added JAX distributed initialization functions for GPU/CPU/TPU |
| primus/backends/maxtext/layers/moe.py | Updated MoE layer to pass bias parameters |
| primus/backends/maxtext/layers/mixtral.py | New Primus-specific Mixtral decoder layer implementation |
| primus/backends/maxtext/layers/mistral.py | New Primus-specific Mistral decoder layer implementation |
| primus/backends/maxtext/layers/llama2.py | New Primus-specific Llama2 decoder layer implementation |
| primus/backends/maxtext/layers/gemma2.py | New Primus-specific Gemma2 decoder layer implementation |
| primus/backends/maxtext/layers/gemma.py | New Primus-specific Gemma decoder layer implementation |
| primus/backends/maxtext/layers/attentions.py | Removed entire attention implementation file |
| primus/backends/maxtext/layers/attention_op.py | Enhanced CUDNN Flash Attention with packing and context parallelism support |
| primus/backends/maxtext/input_pipeline/custom_packed_batch.py | Removed custom packing implementation |
| primus/backends/maxtext/input_pipeline/_hf_data_processing.py | Updated to use grain's native packing and added instruction format conversion |
| primus/backends/maxtext/configs/types.py | New Primus-specific MaxText config with WandB and Turbo support |
| primus/backends/maxtext/checkpointing.py | Added comprehensive checkpoint loading logic with single replica support |
| examples/run_slurm_pretrain.sh | Added NODE_LIST support and timestamped log files |
| examples/run_pretrain.sh | Reorganized AINIC configuration and updated XLA flags |
| examples/run_local_pretrain.sh | Updated default Docker image to maxtext-v26.1 |
| examples/maxtext/configs/MI355X/mixtral_8x7B-pretrain.yaml | Reduced batch size from 12 to 11 |
| examples/maxtext/configs/MI355X/llama3.1_405B-pretrain.yaml | New training configuration for Llama 3.1 405B model |
| examples/maxtext/configs/MI300X/mixtral_8x7B-pretrain.yaml | Updated remat policy |
| docs/cli/PRIMUS-CLI-GUIDE.md | Updated documentation with AINIC configuration examples and corrected command syntax |
Comments suppressed due to low confidence (4)
runner/primus-cli-direct.sh:1
- Array index arithmetic should use proper bash syntax. The expression `$((i+1))` correctly increments `i`, but when used inside an array subscript it should be written as `${args[i+1]}` without the extra parentheses, or the current form needs validation that `i+1` is within array bounds before access.
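A quick illustration of the two subscript forms (the array contents here are made up):

```shell
# Inside a bash array subscript, arithmetic context is implied,
# so args[i+1] already evaluates i+1 without an extra $(( )).
args=(--config runner/maxtext-test.yaml)
i=0
echo "${args[i+1]}"
# Bounds check before accessing i+1, as the comment suggests:
if (( i + 1 < ${#args[@]} )); then
  echo "index $((i+1)) is in bounds"
fi
```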
runner/primus-cli-direct.sh:1
- Python code embedded in a bash script should properly close file handles. The `open('$cfg_path')` should be wrapped in a context manager using `with open('$cfg_path') as f: cfg = yaml.safe_load(f)` to ensure the file is properly closed even if an exception occurs.
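The pattern is easy to verify with the standard library; `json` stands in for `yaml.safe_load` here so the sketch has no third-party dependency, but the `with open(...)` shape is identical:

```python
import json
import os
import tempfile

# Write a small config file to parse (path is temporary/illustrative).
fd, cfg_path = tempfile.mkstemp(suffix=".json")
with os.fdopen(fd, "w") as f:
    json.dump({"backend": "maxtext"}, f)

# The context manager closes the handle even if parsing raises.
with open(cfg_path) as f:
    cfg = json.load(f)

os.remove(cfg_path)
print(cfg["backend"])  # maxtext
```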
primus/backends/maxtext/max_utils.py:1
- Operator precedence issue: the condition mixes `or` and `and` without parentheses. Due to operator precedence, this evaluates as `(self.wandb_save_dir is None) or (self.wandb_save_dir == '' and self.base_output_directory)`, which may not be the intended logic. Add explicit parentheses: `if (self.wandb_save_dir is None or self.wandb_save_dir == '') and self.base_output_directory:`
primus/backends/maxtext/max_utils.py:1
- Same operator precedence issue as above. Should be: `if (self.wandb_exp_name is None or self.wandb_exp_name == '') and self.run_name:`
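Both cases reduce to the same precedence rule, which is easy to check in isolation (the variable values below are made up purely to expose the difference):

```python
# In Python, `and` binds tighter than `or`.
wandb_save_dir = None
base_output_directory = ""  # falsy: the branch should NOT be taken

# As written, parsed as:
#   (wandb_save_dir is None) or (wandb_save_dir == '' and base_output_directory)
as_written = wandb_save_dir is None or wandb_save_dir == '' and base_output_directory

# Intended: group the two "unset" checks, then require base_output_directory.
intended = (wandb_save_dir is None or wandb_save_dir == '') and base_output_directory

print(bool(as_written), bool(intended))  # True False
```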
Accept Copilot commit suggestion.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
```shell
if [ "${BACKEND:-}" == "MaxText" ]; then
    # ------- RCCL/NCCL IB Tuning -------
    export IONIC_LOCKFREE=all
    export NCCL_GDR_COPY_ENABLE=1
    export NCCL_GDR_FLUSH_DISABLE=1
    export NCCL_IB_ECE_ENABLE=0
    export NCCL_IB_FIFO_TC=184
    export NCCL_IB_GID_INDEX=1
    export NCCL_IB_PCI_RELAXED_ORDERING=1
    export NCCL_IB_TC=96
    export NCCL_IB_USE_INLINE=1
    export NCCL_IGNORE_CPU_AFFINITY=1
    export NCCL_PXN_DISABLE=0
    export NET_OPTIONAL_RECV_COMPLETION=1
    export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0
    export RCCL_LL128_FORCE_ENABLE=1
else
    export ANP_HOME_DIR=${ANP_HOME_DIR:-"/opt/amd-anp"}
    export RCCL_HOME_DIR=${RCCL_HOME_DIR:-"/opt/rccl"}
    export MPI_HOME_DIR=${MPI_HOME_DIR:-"/opt/ompi"}
    export NCCL_NET_PLUGIN=librccl-anp.so

    LOG_INFO_RANK0 "RCCL_HOME_DIR: $RCCL_HOME_DIR"
    LOG_INFO_RANK0 "ANP_HOME_DIR: $ANP_HOME_DIR"
    LOG_INFO_RANK0 "MPI_HOME_DIR: $MPI_HOME_DIR"

    # unset NCCL_IB_GID_INDEX
    export NCCL_IB_GID_INDEX=1
    # export NCCL_IB_ROCE_VERSION_NUM=2
    export NCCL_MAX_P2P_CHANNELS=56
    export NCCL_IB_TC=104
    export NCCL_IB_FIFO_TC=192
    export NET_OPTIONAL_RECV_COMPLETION=1
    export NCCL_IB_USE_INLINE=1
    export RCCL_GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING=0
    export NCCL_GDR_FLUSH_DISABLE=1
    export NCCL_DMABUF_ENABLE=0
    export NCCL_IGNORE_CPU_AFFINITY=1
    export NCCL_IB_QPS_PER_CONNECTION=1

    export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/libibverbs:${RCCL_HOME_DIR}/build/release:${ANP_HOME_DIR}/build:${MPI_HOME_DIR}/lib:$LD_LIBRARY_PATH
fi
```
Can you explain why we need to have different flags (NCCL_IB_TC, NCCL_IB_FIFO_TC) when using MaxText backend or not using MaxText backend? I think these flags are more related to cluster settings, right?
@llying-001 can explain this better. I did not change any of this part.
> Can you explain why we need to have different flags (NCCL_IB_TC, NCCL_IB_FIFO_TC) when using MaxText backend or not using MaxText backend? I think these flags are more related to cluster settings, right?
I extracted these env flags for the MaxText backend from https://github.com/ROCm/MAD/blob/develop/scripts/jax-maxtext/jax_maxtext_multinode_benchmark.sh#L305. They are actually related to the cluster rather than the backend. Are the env flags in jax_maxtext_multinode_benchmark.sh configured for the Vultr cluster? @yeandy
For the Megatron/Titan backend, which cluster are the env flags in run_pretrain.sh configured for? @zhenhuang12
It would be great if we could unify them.
```shell
apt install jq dpkg-dev kmod xz-utils -y
apt install libibverbs-dev ibverbs-utils infiniband-diags -y
apt install rdma-core librdmacm-dev libibverbs-dev libibumad-dev -y
LOG_INFO_RANK0 "========== Install IB required packages for Jax/MaxText Done =========="
```
These are not for JAX/MaxText libraries per-se, but rather to add missing dependencies not found in the public docker (like rocm/jax-training:maxtext-v26.1), right? @amd-fuyuajin
We don't need to do this for megatron or torchtitan jobs? Or is this already installed in those dockers? @wenxie-amd
These packages are mainly related to InfiniBand/RDMA libraries. I see they are only installed when NNODES > 1 (line 440). They probably provide networking stack for distributed training. Again, @llying-001 added this and can explain better.
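The NNODES guard described above can be sketched as follows; this is an illustration of the condition as described in the thread, not the actual script:

```shell
# IB/RDMA packages are only needed for multi-node training, so the
# install is gated on the node count (default 1 for local runs).
NNODES=${NNODES:-1}
if [ "$NNODES" -gt 1 ]; then
  echo "multi-node run: install IB/RDMA packages"
else
  echo "single-node run: skip IB/RDMA package install"
fi
```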
Yes, these packages are dependencies required for REBUILD_BNXT that are missing in the public JAX docker image (e.g., rocm/jax-training:maxtext-v26.1), but they are already installed in the Torch docker image (e.g., rocm/primus:v26.1).